import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
import matplotlib.pyplot as plt
from flax.linen.activation import softmax
from einops import rearrange
from jax import random, value_and_grad
import optax
import sys
import wandb
import argparse
import os
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from itertools import chain
from torch.utils.data import DataLoader

from flax.training import orbax_utils
from flax.training import checkpoints, train_state
import orbax.checkpoint

# Command-line interface: every flag is declared once in the table below and
# registered in order, so the generated --help output lists them in this order.
parser = argparse.ArgumentParser(
    description='Tests of multiple-layer-per-block models',
    formatter_class=argparse.RawTextHelpFormatter,
)

_CLI_OPTIONS = (
    ('--lr', dict(default=0.01, type=float, help='Learning rate')),
    ('--mom', dict(default=0.0, type=float, help='Momentum')),
    ('--batch_size', dict(type=int, default=128, help='Batch size')),
    ('--max_len', dict(type=int, default=256, help='Maximum sequence length')),
    ('--arch', dict(type=str, default='LLM', help='Model architecture')),
    ('--dataset', dict(type=str, default='C4', help='Dataset')),
    ('--optimizer', dict(default='adam', help='Optimizer')),
    ('--width', dict(type=int, default=32, help='Model width')),
    ('--heads', dict(type=int, default=16, help='Number of heads in attention')),
    ('--depth', dict(type=int, default=12, help='Model depth')),
    ('--beta', dict(type=float, default=12.0, help='Scaling factor for the residual branch')),
    ('--gamma_zero', dict(type=float, default=0.25, help='Controls the amount of feature learning')),
    ('--scale_exp', dict(type=float, default=1.0, help='Scaling exponent')),
    ('--steps', dict(type=int, default=10000, help='Number of training steps')),
    ('--save_model', dict(action='store_true', help='Save the model')),
)
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)

# Parsed once at import time; the rest of the script reads (and the sweep
# loop overwrites) attributes on this namespace.
args = parser.parse_args()

# Output directory for loss curves and checkpoints
save_dir = '/path/to/save/dir'


def _sweep_values(requested, grid):
    """Return ``[requested]``, or the whole ``grid`` when ``requested`` is the -1 sentinel."""
    return grid if requested == -1 else [requested]


# Hyper-parameter grids: passing -1 on the command line sweeps the full grid,
# any other value pins that hyper-parameter to the single requested setting.
depths = _sweep_values(args.depth, [4, 8, 16, 32, 64])
widths = _sweep_values(args.width, [8, 16, 32, 64, 128, 256])
heads = _sweep_values(args.heads, [8, 16, 32, 64, 128, 256])
lrs = _sweep_values(args.lr, np.logspace(-3, 0, 10))
adam = args.optimizer == "adam"

# Data loading and tokenization
# Open the English C4 training split in streaming mode and shuffle it with a
# fixed seed through a 10000-example buffer. Any failure here is reported and
# terminates the script with exit status 1.
try:
    ds = load_dataset("allenai/c4", 'en', streaming=True)['train']
    shuff_ds = ds.shuffle(seed=0, buffer_size=10000)
except Exception as e:
    print(f"Error loading dataset: {e}")
    sys.exit(1)

# GPT-2 tokenizer with '[PAD]' registered as its pad token.
tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# Taken after the pad token is registered; the Transformer sizes its
# embedding table and output layer from this value.
VOCAB_SIZE = len(tokenizer)

# Function to encode the text
def encode(examples):
    """Tokenize a batch of raw-text examples with the module-level tokenizer.

    Args:
        examples: Mapping with a ``'text'`` entry holding the strings to
            tokenize.

    Returns:
        The tokenizer's output for ``examples['text']``.

    Raises:
        Exception: Whatever the lookup or the tokenizer raised. The error is
            printed first and then propagated unchanged.
    """
    try:
        return tokenizer(examples['text'])
    except Exception as e:
        print(f"Error during tokenization: {e}")
        # Re-raise instead of returning None: a None result carries no
        # `input_ids`, which the collate function downstream reads, so the
        # failure would only resurface later with a misleading error.
        raise

# Function to group texts to the specified maximum length
def group_texts(examples, max_len=None):
    """Concatenate every column of a batch and re-split it into fixed-size chunks.

    Args:
        examples: Mapping from column name to a list of sequences (one per
            example in the batch).
        max_len: Chunk length. Defaults to the ``--max_len`` command-line
            value when omitted.

    Returns:
        A dict with the same keys, where each value is a list of chunks of at
        most ``max_len`` items. When the concatenated length reaches
        ``max_len`` the trailing remainder is dropped, so every chunk is
        full; shorter batches yield a single short chunk. An empty mapping
        yields an empty dict.
    """
    if max_len is None:
        max_len = args.max_len

    keys = list(examples.keys())
    if not keys:
        return {}

    concatenated_examples = {k: list(chain.from_iterable(examples[k])) for k in keys}

    # Size the chunks from the token column when there is one. Falling back to
    # "whatever column comes first" would use a character count if a raw-text
    # column precedes the token columns, producing mostly empty token chunks.
    length_key = 'input_ids' if 'input_ids' in concatenated_examples else keys[0]
    total_length = len(concatenated_examples[length_key])
    if total_length >= max_len:
        total_length = (total_length // max_len) * max_len

    # Every column is cut at the same offsets so all columns keep the same
    # number of rows.
    return {
        k: [seq[start: start + max_len] for start in range(0, total_length, max_len)]
        for k, seq in concatenated_examples.items()
    }

try:
    # Drop the raw "text" column together with the metadata columns once the
    # batch is tokenized. group_texts chunks every remaining column and takes
    # its length from the first one, so only token columns may remain;
    # otherwise chunking is driven by character counts and most `input_ids`
    # rows come out empty or short.
    dataset = shuff_ds.map(encode, batched=True, remove_columns=["text", "timestamp", "url"])
    dataset_grouped = dataset.map(group_texts, batched=True)
except Exception as e:
    print(f"Error processing dataset: {e}")
    sys.exit(1)


def collate_fn(x):
    """Collect the ``input_ids`` of every example in a batch into a plain list."""
    return [xi["input_ids"] for xi in x]


dataloader = DataLoader(dataset_grouped, batch_size=args.batch_size, collate_fn=collate_fn)

# Define custom LayerNorm class
class LN_Fixed(nn.Module):
    """Layer normalization over the last axis with no learnable scale or bias.

    Attributes:
        eps: Constant added to the variance before the square root.
    """
    eps: jnp.float32 = 1.0e-6

    @nn.compact
    def __call__(self, x):
        """Return ``x`` centred and scaled to unit variance along its last axis."""
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)
        return (x - mean) / jnp.sqrt(var + self.eps)

# Define model components
class Causal_Attention(nn.Module):
    """Multi-head causal self-attention with width-dependent scaling.

    Attributes:
        scale_exp: Exponent ``a`` in the ``1 / dim**a`` factor applied to the
            query-key logits.
        dim: Per-head feature dimension.
        heads: Number of attention heads.
        qk_ln: When True, queries and keys are layer-normalized (without
            learnable parameters) before the logits are formed.
    """
    scale_exp: jnp.float32
    dim: int
    heads: int
    qk_ln: bool = True

    def setup(self):
        """Create the projection layers and the query/key normalizers."""
        self.c = 1.5 - self.scale_exp  # Exponent shared by the QK init and the QK output rescaling
        kif_qk = nn.initializers.normal(stddev=self.dim**(self.c - 0.5))  # Initializer for QK
        kif_v = nn.initializers.normal(stddev=1.0)  # Initializer for V and the output projection
        self.qk_layer = nn.Dense(features=2 * self.heads * self.dim, kernel_init=kif_qk, use_bias=False)
        self.v_layer = nn.Dense(features=self.heads * self.dim, kernel_init=kif_v, use_bias=False)
        self.out_layer = nn.Dense(features=self.heads * self.dim, kernel_init=kif_v, use_bias=False)
        self.q_norm = LN_Fixed()
        self.k_norm = LN_Fixed()

    def __call__(self, inputs):
        """Apply causal self-attention.

        Args:
            inputs: Array of shape (batch, loc, features).

        Returns:
            Array of shape (batch, loc, heads * dim).
        """
        # Joint query/key projection: (batch, loc, 2*heads*dim)
        qk = self.qk_layer(inputs) / self.heads**0.5 / self.dim**(self.c)
        # Each head receives a 2*dim slice, which is then halved into Q and K.
        qk = rearrange(qk, 'b l (h d) -> b h l d', h=self.heads)
        q, k = jnp.split(qk, 2, axis=-1)  # each (batch, heads, loc, dim)
        if self.qk_ln:
            q = self.q_norm(q)
            k = self.k_norm(k)

        v = self.v_layer(inputs) / jnp.sqrt(inputs.shape[-1])  # (batch, loc, heads*dim)
        v = rearrange(v, 'b l (h d) -> b h l d', h=self.heads)  # (batch, heads, loc, dim)

        # Attention logits: (batch, heads, loc, loc)
        scores = 1.0 / self.dim**self.scale_exp * jnp.einsum('bhqd,bhkd->bhqk', q, k)

        # Causal mask: position q may only attend to positions k <= q. Masked
        # logits are set to the most negative finite value so they receive
        # exactly zero weight, and the softmax subtracts the row maximum
        # internally, which keeps the exponentials from overflowing for large
        # logits (a plain exp(scores) * mask / sum does not).
        seq_len = scores.shape[-1]
        causal = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        scores = jnp.where(causal, scores, jnp.finfo(scores.dtype).min)
        weights = jax.nn.softmax(scores, axis=-1)

        out = jnp.einsum('bhqk,bhkd->bhqd', weights, v)  # (batch, heads, loc, dim)
        out = rearrange(out, 'b h l d -> b l (h d)')  # (batch, loc, heads*dim)
        out = self.out_layer(out) / jnp.sqrt(out.shape[-1])  # Output projection
        return out

class MLP_Block(nn.Module):
    """Two-layer feed-forward block: Dense -> GELU -> Dense, all of equal width.

    Attributes:
        features: Output width of both dense layers.
    """
    features: int

    @nn.compact
    def __call__(self, x):
        """Run the block on ``x`` and return an array with ``features`` channels."""
        width = self.features
        unit_normal = nn.initializers.normal(stddev=1.0)
        # Kernels are drawn with unit variance, so each layer's output is
        # divided by sqrt(width).
        norm = jnp.sqrt(width)

        hidden = nn.Dense(features=width, kernel_init=unit_normal, use_bias=False)(x) / norm
        hidden = nn.gelu(hidden)
        return nn.Dense(features=width, kernel_init=unit_normal, use_bias=False)(hidden) / norm

class PositionalEncoding(nn.Module):
    """Learned additive positional embedding.

    Attributes:
        d_model: Hidden dimensionality of the input.
        scale: Standard deviation used to initialize the table; the table is
            divided by the same value when it is added to the input.
        max_len: Longest sequence to expect (defaults to ``--max_len``).
    """
    d_model: int
    scale: jnp.float32
    max_len: int = args.max_len

    def setup(self):
        """Allocate the learnable table of shape (1, 1 + max_len, d_model)."""
        table_shape = (1, 1 + self.max_len, self.d_model)
        self.pos_embedding = self.param(
            'pos_embedding',
            nn.initializers.normal(stddev=self.scale),
            table_shape,
        )

    def __call__(self, x, train=True):
        """Add the rescaled embeddings of the first ``x.shape[1]`` positions to ``x``.

        ``x`` must be rank 3; ``train`` is accepted but not used.
        """
        _, seq_len, _ = x.shape
        positions = self.pos_embedding[:, :seq_len]
        return x + positions / self.scale

class Transformer(nn.Module):
    """Pre-LayerNorm decoder-only transformer with width/depth-dependent scaling.

    Attributes:
        dim: Per-head feature dimension.
        heads: Number of attention heads; the residual width is ``heads * dim``.
        depth: Number of (attention, MLP) residual blocks.
        scale_exp: Attention logit scaling exponent, forwarded to
            ``Causal_Attention``.
        adam_scale: Selects the exponents used in the embedding and readout
            scale factors below (the training code passes 1 for Adam and 0
            for SGD).
        beta: Each residual branch is multiplied by ``beta / depth``.
    """
    dim: int
    heads: int
    depth: int
    scale_exp: jnp.float32
    adam_scale: int
    beta: jnp.float32

    @nn.compact
    def __call__(self, x, train=True):
        """Map integer token ids to per-position vocabulary logits.

        Args:
            x: Integer token ids; the result of the embedding lookup must be
                rank 3 (batch, loc, width).
            train: Accepted but not used.

        Returns:
            Array whose last axis has ``VOCAB_SIZE`` entries.
        """
        N = self.heads * self.dim  # Total dimension after concatenating heads
        L = self.depth  # Depth of the transformer
        s = self.adam_scale
        depth_ratio = L / self.beta

        # Standard deviation shared by the token-embedding initializer and the
        # positional-embedding table.
        init_std = N**(-0.5 * s) * depth_ratio**(0.5 * (1 - s))

        # Embedding layer, rescaled by the reciprocal of its init scale
        x = nn.Embed(VOCAB_SIZE, N, embedding_init=nn.initializers.normal(stddev=init_std))(x)
        x = depth_ratio**(-0.5 * (1 - s)) * N**(0.5 * s) * x

        # Positional encoding
        x = PositionalEncoding(d_model=N, scale=init_std)(x)

        # Transformer blocks: pre-norm residual branches scaled by beta / L
        for _ in range(self.depth):
            h = LN_Fixed()(x)
            x = x + self.beta / L * Causal_Attention(dim=self.dim, scale_exp=self.scale_exp, heads=self.heads)(h)
            h = LN_Fixed()(x)
            x = x + self.beta / L * MLP_Block(features=N)(h)

        # Final LayerNorm and output layer. The readout kernel is drawn with
        # stddev 0.0, i.e. it starts at zero.
        x = LN_Fixed()(x)
        readout = nn.Dense(features=VOCAB_SIZE, use_bias=True, kernel_init=nn.initializers.normal(stddev=0.0))
        x = depth_ratio**(-0.5 * (1 - s)) * readout(x) / N**(1.0 - 0.5 * s)
        return x

# Function to create the train state
def create_train_state(rng, model, learning_rate, tx=None):
    """Initialize the model parameters and wrap them in a ``TrainState``.

    Args:
        rng: PRNG key used for parameter initialization.
        model: Flax module to initialize; it is called on a dummy int32 batch
            of shape ``(1, args.max_len)``.
        learning_rate: Learning rate for the default AdamW optimizer. Ignored
            when ``tx`` is given.
        tx: Optional Optax gradient transformation to attach to the state.
            When omitted, ``optax.adamw(learning_rate)`` is used.

    Returns:
        A ``train_state.TrainState`` holding ``model.apply``, the freshly
        initialized parameters and the optimizer.
    """
    dummy_input = jnp.ones([1, args.max_len], jnp.int32)  # Dummy input for initialization
    params = model.init(rng, dummy_input)['params']
    if tx is None:
        tx = optax.adamw(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Training function
# Update counts after which an intermediate checkpoint is written.
# NOTE(review): the step counter is already >= 1 when it is compared against
# this table, so no checkpoint of the initial (step-0) parameters is written.
_CHECKPOINT_STEPS = (100, 1000, 5000)


def _save_checkpoint(state, losses, step):
    """Write train state, run config and loss history to ``save_path/-ckpt_{step}``."""
    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
    # Store the config as a plain dict: an argparse.Namespace is not a pytree
    # container, so it would reach the checkpointer as one opaque leaf.
    ckpt = {'model': state, 'config': dict(vars(args)), 'losses': losses}
    save_args = orbax_utils.save_args_from_target(ckpt)
    checkpointer.save(os.path.join(save_path, f'-ckpt_{step}'), ckpt, save_args=save_args)


def train_model(param_args, opt_args, data=None, adam=True):
    """Train the transformer on batches drawn from the module-level ``dataloader``.

    Args:
        param_args: ``(dim, heads, depth, scale_exp, beta)`` model settings.
        opt_args: ``(lr, gamma, T)``: base learning rate, divisor applied to
            the logits inside the loss, and number of optimizer updates.
        data: Unused; kept so existing callers keep working.
        adam: When True, train with AdamW under a warmup + cosine schedule
            peaking at ``lr / sqrt(heads * dim)``; otherwise use SGD with
            learning rate ``gamma**2 * heads * dim * lr``.

    Returns:
        ``(losses, params, model)``: the per-update losses, the final
        parameters and the model instance.

    Side effects:
        Logs every loss to stdout and wandb. When ``--save_model`` is set,
        writes checkpoints under the module-level ``save_path``.
    """
    dim, heads, depth, scale_exp, beta = param_args
    lr, gamma, T = opt_args
    width = heads * dim

    if adam:
        adam_scale = 1
        schedule = optax.warmup_cosine_decay_schedule(init_value=0.0, peak_value=lr / jnp.sqrt(width), warmup_steps=100, decay_steps=T, end_value=0.0)
        optimizer = optax.adamw(schedule, eps=1e-20, weight_decay=0.0)
    else:
        adam_scale = 0
        optimizer = optax.sgd(gamma**2 * width * lr)

    model = Transformer(dim, heads, depth, scale_exp=scale_exp, adam_scale=adam_scale, beta=beta)
    rng = random.PRNGKey(0)
    # create_train_state attaches its own optimizer; keep only its initialized
    # parameters and bind them to the optimizer configured above, which would
    # otherwise never be used.
    init_state = create_train_state(rng, model, lr)
    state = train_state.TrainState.create(apply_fn=model.apply, params=init_state.params, tx=optimizer)

    def loss_fn(params, Xb, yb):
        """Mean next-token cross-entropy on logits divided by ``gamma``."""
        logits = model.apply({'params': params}, Xb) / gamma
        return optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=yb).mean()

    val_grad_fn = jax.jit(value_and_grad(loss_fn))

    losses = []
    steps = 0

    for batch in dataloader:
        # Only train on batches in which every sequence is exactly max_len long.
        if not batch or any(len(seq) != args.max_len for seq in batch):
            continue

        tokens = jnp.array(batch, dtype=jnp.int32)
        # Inputs are tokens [0, max_len - 1); targets are the same sequence
        # shifted left by one position.
        loss, grads = val_grad_fn(state.params, tokens[:, :-1], tokens[:, 1:])
        state = state.apply_gradients(grads=grads)

        sys.stdout.write(f'\r loss = {loss}')
        wandb.log({'loss': loss})
        losses.append(loss)
        steps += 1

        if args.save_model and steps in _CHECKPOINT_STEPS:
            _save_checkpoint(state, losses, steps)

        # Stop after exactly T updates (a `steps > T` test would run T + 1).
        if steps >= T:
            break

    return losses, state.params, model


# Function to generate a run name based on the arguments
def get_run_name(args):
    """Build the '/'-separated run identifier from the run's hyper-parameters.

    Each segment has the form ``<name>_<value>``; the learning rate is
    formatted with four decimal places.
    """
    segments = (
        f"model_{args.arch}",
        f"dataset_{args.dataset}",
        f"optimizer_{args.optimizer}",
        f"lr_{args.lr:.4f}",
        f"batch_size_{args.batch_size}",
        f"steps_{args.steps}",
        f"width_{args.width}",
        f"heads_{args.heads}",
        f"depth_{args.depth}",
        f"scale_exp_{args.scale_exp}",
        f"beta_{args.beta}",
        f"gamma_zero_{args.gamma_zero}",
    )
    return "/".join(segments)

# Run experiments with different parameter settings
# One wandb run is started per (width, depth, heads, lr) combination. A failure
# inside one run is reported and the sweep moves on to the next combination.
for dim in widths:
    for depth in depths:
        for head in heads:
            for lr in lrs:
                # Record the current combination on `args` so that the run
                # name, the wandb config and the checkpoints all describe it.
                args.width = dim
                args.depth = depth
                args.lr = lr
                args.heads = head

                run_name = get_run_name(args)
                # train_model reads this module-level path for its checkpoints.
                save_path = os.path.join(save_dir, run_name.replace("/", "-"))
                opt_args = (args.lr, args.gamma_zero, args.steps)
                param_args = (dim, head, depth, args.scale_exp, args.beta)

                wandb.init(project="Infinite Head Limit", config=args.__dict__)
                wandb.run.name = run_name

                try:
                    losses, params, model = train_model(param_args, opt_args, adam=adam)
                    # Make sure the output directory exists so a finished run
                    # is not lost to a missing-directory error.
                    os.makedirs(save_dir, exist_ok=True)
                    np.save(save_path, losses)

                    if args.save_model:
                        # The checkpointer must be created here: no
                        # module-level instance exists for this save.
                        orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
                        state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optax.adam(args.lr))
                        # Snapshot the config as a plain dict; an
                        # argparse.Namespace is not a pytree container.
                        ckpt = {'model': state, 'config': dict(vars(args)), 'losses': losses}
                        save_args = orbax_utils.save_args_from_target(ckpt)
                        orbax_checkpointer.save(save_path + '_ckpts', ckpt, save_args=save_args)

                except Exception as e:
                    print(f"Error during training: {e}")
                finally:
                    wandb.finish()
